import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.utils import save_image


import torch


def generate_gaussian_masks(B, C, H, W, num_bins, device, base_sigma=0.05):
    """
    Generate radial Gaussian "ring" masks over a normalized image grid.

    Pixel coordinates are mapped to [-1, 1] along each axis, and every pixel's
    Euclidean distance from the image center is measured. With
    ``d = linspace(0, 1, num_bins) ** 2 * sqrt(2) / 2`` (quadratic spacing, so
    rings are denser near the center), bin ``i`` is a Gaussian in that distance
    centered at ``2 * d[i]``. Bin 0 is therefore a blob at the center, and the
    last bin peaks at the corners (distance sqrt(2)). The standard deviation of
    bin ``i`` is ``base_sigma * (1 + 10 * d[i] / (sqrt(2) / 2))``, growing from
    ``base_sigma`` at the center to ``11 * base_sigma`` for the outermost ring.

    The masks do not depend on the batch or channel index: every
    ``(batch, channel)`` slice holds the same stack of rings.

    Args:
        B (int): Batch size.
        C (int): Number of channels.
        H (int): Mask height in pixels.
        W (int): Mask width in pixels.
        num_bins (int): Number of rings per channel.
        device: Torch device on which the result is allocated.
        base_sigma (float): Standard deviation of the central bin, in
            normalized-coordinate units.

    Returns:
        masks: Contiguous float32 tensor of shape (B, C, num_bins, H, W) with
        unnormalized Gaussian values in [0, 1].
    """
    # Normalized pixel grid in [-1, 1], built on CPU and then moved to `device`.
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij"
    )
    x = x.to(device)
    y = y.to(device)
    dist = torch.sqrt(x**2 + y**2)  # (H, W) distance from the image center

    # Quadratically spaced ring radii: denser near the center.
    max_dist = math.sqrt(2) / 2
    distances = torch.linspace(0, 1, num_bins) ** 2 * max_dist  # (num_bins,)

    # Ring width grows linearly with the ring radius.
    sigmas = base_sigma * (1 + 10 * distances / max_dist)
    # Ring centers in grid units: 2 * distances spans [0, sqrt(2)], the
    # center-to-corner distance of the [-1, 1] grid. Bin 0 has center 0, which
    # reduces to the plain centered Gaussian exp(-dist**2 / (2 * sigma**2)).
    centers = distances * 2

    centers = centers.to(device).view(num_bins, 1, 1)
    sigmas = sigmas.to(device).view(num_bins, 1, 1)
    rings = torch.exp(-((dist - centers) ** 2) / (2 * sigmas**2))  # (num_bins, H, W)

    # Identical for every (batch, channel); materialize so callers receive an
    # independent, contiguous tensor rather than a broadcast view.
    return rings.expand(B, C, num_bins, H, W).contiguous()


class AvgPool2d(nn.Module):
    """Sliding-window 2-D average pooling computed from cumulative sums.

    ``replace_layers`` in this file swaps it in for ``nn.AdaptiveAvgPool2d``
    modules whose ``output_size`` is 1. Instead of one global mean, each output
    pixel is the mean over a ``kernel_size`` window. With ``auto_pad`` the
    result is replicate-padded back to the input's spatial size.

    Args:
        kernel_size: ``(kh, kw)`` window, or None to derive it from
            ``base_size`` on the first forward call.
        base_size: int or ``(h, w)`` reference window; used only when
            ``kernel_size`` is None.
        auto_pad: if True, replicate-pad the pooled map to the input's H x W.
        fast_imp: if True, use the strided approximation in ``forward``
            (marked in the code as non-equivalent to the exact path).
        train_size: sequence whose last two entries are the reference (H, W)
            resolution used to scale ``base_size``. NOTE(review): presumably
            the training input shape; confirm with callers.
    """

    def __init__(
        self,
        kernel_size=None,
        base_size=None,
        auto_pad=True,
        fast_imp=False,
        train_size=None,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        # Candidate subsampling strides, tried largest first; 1 always divides.
        self.rs = [5, 4, 3, 2, 1]
        # Upper bounds on the fast-path strides; rescaled in forward() when the
        # kernel size is derived from base_size.
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # NOTE(review): "stride" is reported as kernel_size, but the exact path
        # evaluates the window at every offset (stride 1).
        return "kernel_size={}, base_size={}, stride={}, fast_imp={}".format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        """Average-pool ``x`` over sliding windows on dims 2 and 3.

        Args:
            x: 4-D tensor (unpacked as n, c, h, w below).

        Returns:
            If the window covers the whole input, ``F.adaptive_avg_pool2d(x, 1)``
            with shape (N, C, 1, 1); auto_pad is NOT applied on that path.
            Otherwise the windowed means, replicate-padded to the input's
            (H, W) when ``auto_pad`` is True. Without auto_pad, the exact path
            yields (H - kh + 1, W - kw + 1).
        """
        if self.kernel_size is None and self.base_size:
            # Derive the window from base_size scaled by the current input size
            # relative to train_size, then cache it on the module. Later calls
            # reuse this kernel_size even if the input resolution changes.
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # Window covers the whole input: fall back to the global mean.
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            # NOTE(review): unreachable; the identical condition returned above.
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Largest stride in self.rs that divides h (resp. w).
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # Summed-area table of the subsampled input.
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                # Window size in subsampled pixels, capped to fit the table.
                # NOTE(review): if kernel_size // r is 0, the slices below are
                # empty and the divisor is 0; confirm kernel_size >= r in practice.
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(
                    w - 1, self.kernel_size[1] // r2
                )
                # Window sums via the four-corner difference, then the mean.
                out = (
                    s[:, :, :-k1, :-k2]
                    - s[:, :, :-k1, k2:]
                    - s[:, :, k1:, :-k2]
                    + s[:, :, k1:, k2:]
                ) / (k1 * k2)
                # Upsample back by the subsampling strides (interpolate's
                # default mode is nearest).
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            # 2-D prefix sums. The in-place cumsum_ acts on the new tensor
            # returned by the first cumsum, so x itself is not modified.
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            # Clamp the window to the input size.
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = (
                s[:, :, :-k1, :-k2],
                s[:, :, :-k1, k2:],
                s[:, :, k1:, :-k2],
                s[:, :, k1:, k2:],
            )
            # Sum over every k1 x k2 window (stride 1): shape (h-k1+1, w-k2+1).
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            # F.pad order is (left, right, top, bottom); an odd difference puts
            # the extra pixel on the right/bottom.
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode="replicate")

        return out


def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively replace every ``nn.AdaptiveAvgPool2d`` in *model* with ``AvgPool2d``.

    The model is modified in place. Children are visited depth-first, so nested
    pools inside compound modules are replaced too.

    Args:
        model: ``nn.Module`` to modify in place.
        base_size: forwarded to ``AvgPool2d(base_size=...)``.
        train_size: forwarded to ``AvgPool2d(train_size=...)``.
        fast_imp: forwarded to ``AvgPool2d(fast_imp=...)``.
        **kwargs: passed along on recursion; otherwise unused.

    Raises:
        AssertionError: if an ``nn.AdaptiveAvgPool2d`` whose ``output_size`` is
            not equal to the int ``1`` is found. A tuple such as ``(1, 1)`` does
            not compare equal to ``1`` and is rejected. Replacements made before
            the failure are not rolled back.
    """
    for name, child in model.named_children():
        # Recursing into a leaf is a no-op, so no children check is needed.
        replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if isinstance(child, nn.AdaptiveAvgPool2d):
            # Validate explicitly, not with `assert`, so the check survives
            # `python -O`. Check before building the replacement. AssertionError
            # is kept for compatibility with existing callers.
            if child.output_size != 1:
                raise AssertionError(
                    f"replace_layers: expected AdaptiveAvgPool2d output_size == 1 "
                    f"for {name!r}, got {child.output_size!r}"
                )
            pool = AvgPool2d(
                base_size=base_size, fast_imp=fast_imp, train_size=train_size
            )
            setattr(model, name, pool)



class LayerNormFunction(torch.autograd.Function):
    """Layer normalization over dim 1 (channels) of a 4-D tensor, with a
    hand-written backward pass.

    For every (n, h, w) position, the C values are shifted to zero mean and
    scaled by ``1 / sqrt(var + eps)`` (biased variance). Then the per-channel
    affine ``weight`` and ``bias`` are applied.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalize ``x`` across dim 1 and apply the per-channel affine map.

        Args:
            ctx: autograd context; stores ``eps`` and the tensors backward needs.
            x: 4-D input (the size unpacking below rejects other ranks).
            weight: per-channel scale with C elements.
            bias: per-channel shift with C elements.
            eps: float added to the variance for numerical stability.

        Returns:
            Tensor with the same shape as ``x``.
        """
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)  # biased (population) variance
        y = (x - mu) / (var + eps).sqrt()
        # Keep the normalized, pre-affine activations for the closed-form backward.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, eps)``; ``eps`` gets None.

        With ``y`` the normalized input, ``g = grad_output * weight`` and
        ``sigma = sqrt(var + eps)``, the input gradient is
        ``(g - mean_C(g) - y * mean_C(g * y)) / sigma``. The weight and bias
        gradients sum ``grad_output * y`` and ``grad_output`` over every axis
        except the channel axis.
        """
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated and warns on every backward call.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return (
            gx,
            (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
            None,
        )


class LayerNorm2d(nn.Module):
    """Channel-wise layer normalization module for 4-D tensors.

    Holds one learnable scale (initialized to 1) and one learnable shift
    (initialized to 0) per channel, and delegates the computation to
    ``LayerNormFunction``.

    Args:
        channels: number of channels (size of ``weight`` and ``bias``).
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it, in this order.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        normalize = LayerNormFunction.apply
        return normalize(x, self.weight, self.bias, self.eps)


class SimpleGate(nn.Module):
    """Parameter-free gate: split channels into two halves and multiply them."""

    def forward(self, x):
        # chunk(2) along channels; the halves are multiplied element-wise.
        first_half, second_half = x.chunk(2, dim=1)
        gated = first_half * second_half
        return gated


class NAFBlock(nn.Module):
    """Residual block with a gated depthwise-conv branch and a gated
    pointwise feed-forward branch.

    Each branch is added back to its input through a learnable per-channel
    scale (``beta`` and ``gamma``), both initialized to zero.

    Args:
        c: number of input and output channels.
        DW_Expand: channel expansion of the first branch before gating.
        FFN_Expand: channel expansion of the feed-forward branch before gating.
        drop_out_rate: dropout probability on each branch output; 0 disables it.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
        super().__init__()
        dw_channel = c * DW_Expand
        gated_dw = dw_channel // 2
        ffn_channel = FFN_Expand * c

        # Branch 1: 1x1 expand -> 3x3 depthwise -> gate -> attention -> 1x1 project.
        self.conv1 = nn.Conv2d(
            c, dw_channel, kernel_size=1, stride=1, padding=0, groups=1, bias=True
        )
        self.conv2 = nn.Conv2d(
            dw_channel,
            dw_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dw_channel,
            bias=True,
        )
        self.conv3 = nn.Conv2d(
            gated_dw, c, kernel_size=1, stride=1, padding=0, groups=1, bias=True
        )

        # Simplified channel attention: global average -> 1x1 conv, used as a
        # per-channel multiplier.
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                gated_dw, gated_dw, kernel_size=1, stride=1, padding=0, groups=1, bias=True
            ),
        )

        # Gate shared by both branches (it has no parameters).
        self.sg = SimpleGate()

        # Branch 2: 1x1 expand -> gate -> 1x1 project.
        self.conv4 = nn.Conv2d(
            c, ffn_channel, kernel_size=1, stride=1, padding=0, groups=1, bias=True
        )
        self.conv5 = nn.Conv2d(
            ffn_channel // 2, c, kernel_size=1, stride=1, padding=0, groups=1, bias=True
        )

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        def make_dropout():
            if drop_out_rate > 0.0:
                return nn.Dropout(drop_out_rate)
            return nn.Identity()

        self.dropout1 = make_dropout()
        self.dropout2 = make_dropout()

        # Residual scales start at zero, so each branch initially contributes nothing.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Branch 1 on the normalized input.
        h = self.norm1(inp)
        h = self.sg(self.conv2(self.conv1(h)))
        h = self.conv3(h * self.sca(h))
        y = inp + self.dropout1(h) * self.beta

        # Branch 2 on the normalized intermediate result.
        h = self.sg(self.conv4(self.norm2(y)))
        h = self.conv5(h)
        return y + self.dropout2(h) * self.gamma


class CFMG(nn.Module):
    """Encoder network that predicts per-channel radial-bin weights and renders
    them into a per-pixel softmax map.

    Pipeline:
    - Inputs are zero-padded to a multiple of ``2 ** len(enc_blk_nums)``.
    - The image, its log-magnitude spectrum and a constant conditioning map
      are concatenated and passed through NAFBlock encoders with stride-2
      downsampling, then the middle blocks and a global average pool.
    - Two linear layers produce ``output_channel * num_bins`` logits.
    - These logits weight radial Gaussian ring masks, and the result is
      softmax-normalized over channels at every pixel.

    Args:
        img_channel: channels seen by ``intro``. This must equal
            ``2 * C_in + 1`` for an input with ``C_in`` channels: image +
            spectrum + conditioning map.
        output_channel: number of output channels.
        width: channels after ``intro``; doubled by every downsampling step.
        middle_blk_num: number of NAFBlocks at the bottleneck.
        enc_blk_nums: NAFBlocks per encoder stage (one stage per entry).
        num_bins: number of radial bins per output channel.
    """

    def __init__(
        self,
        img_channel=7,
        output_channel=3,
        width=32,
        middle_blk_num=1,
        enc_blk_nums=(1, 1, 1, 1),  # tuple: avoids a shared mutable default
        num_bins=25,
    ):
        super().__init__()

        self.num_bins = num_bins

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=output_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )

        self.output_channel = output_channel

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.soft = nn.Softmax(dim=-1)  # not used in forward; parameter-free
        # No hard-coded .cuda(): the layers follow the module's .to()/.cuda(),
        # so the model also works on CPU.
        self.fclayer_1 = nn.Linear(chan, self.output_channel * self.num_bins)
        self.dropout_1 = nn.Dropout(p=0.2)  # applied to pooled features, before fclayer_1
        self.fclayer_last = nn.Linear(
            self.output_channel * self.num_bins, self.output_channel * self.num_bins
        )
        self.dropout_2 = nn.Dropout(p=0.2)  # applied between fclayer_1 and fclayer_last

        self.padder_size = 2 ** len(self.encoders)

        # These buffers are not read by forward. They are kept so the
        # state_dict keys (mask_height, mask_width) stay unchanged.
        self.register_buffer("ring_mask", None)
        self.register_buffer("mask_height", torch.tensor(0, dtype=torch.long))
        self.register_buffer("mask_width", torch.tensor(0, dtype=torch.long))

    @staticmethod
    def _log_spectrum(x):
        """Return ``log10(1 + |FFT2(x)|)`` over the last two axes, zero frequency centered.

        Only the spatial axes are shifted. ``fftshift``'s default (all dims)
        would also roll the batch and channel axes, pairing each sample with
        another sample's spectrum and permuting its channels.
        NOTE: checkpoints trained with the all-dims shift saw that rolled layout.
        """
        spec = torch.fft.fftn(x, dim=(-1, -2))
        spec = torch.fft.fftshift(spec, dim=(-2, -1))
        return torch.log10(torch.abs(spec) + 1)

    def forward(self, inp, c=0.5, inp_ref=None):
        """Predict a per-pixel channel distribution and the per-bin weights.

        Args:
            inp: (B, C_in, H, W) batch. H and W are zero-padded internally to
                multiples of ``self.padder_size``.
            c: conditioning value. A non-tensor is broadcast to a constant
                (B, 1, H_pad, W_pad) map. A tensor is concatenated as-is, so it
                must match (B, ·, H_pad, W_pad).
            inp_ref: unused.

        Returns:
            final_image: (B, output_channel, H, W); softmax over dim 1 at each pixel.
            value_prob: (B, output_channel, num_bins); softmax over dim 1, so
                for each bin the weights across output channels sum to 1.
        """
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        inp_fq = self._log_spectrum(inp)

        if not torch.is_tensor(c):
            c_tensor = torch.full(
                (B, 1, inp.size(2), inp.size(3)), c, device=inp.device, dtype=inp.dtype
            )
        else:
            c_tensor = c

        x = self.intro(torch.cat([inp, inp_fq, c_tensor], dim=1))

        for encoder, down in zip(self.encoders, self.downs):
            x = down(encoder(x))

        x = self.middle_blks(x)
        x = self.avgpool(x).view(B, -1)

        num_bins = self.num_bins
        num_channels = self.output_channel

        # dropout -> fc -> dropout -> fc
        x = self.dropout_1(x)
        x = self.fclayer_1(x)
        x = self.dropout_2(x)
        value_prob = self.fclayer_last(x)

        value_prob = value_prob.view(B, num_channels, num_bins)
        value_prob = F.softmax(value_prob, dim=1)

        # Ring masks at the original (unpadded) resolution.
        masks = generate_gaussian_masks(
            B, num_channels, H, W, num_bins, device=inp.device
        )  # (B, num_channels, num_bins, H, W)

        weights = value_prob.view(B, num_channels, num_bins, 1, 1)
        channel_images = (masks * weights).sum(dim=2)  # (B, num_channels, H, W)

        final_image = F.softmax(channel_images, dim=1)  # (B, num_channels, H, W)

        return final_image, value_prob

    def check_image_size(self, x):
        """Zero-pad H and W on the bottom/right up to multiples of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

    def save_model(self, outf):
        """Save this module's ``state_dict`` to ``outf`` with ``torch.save``."""
        state_dict = self.state_dict()
        torch.save(state_dict, outf)

